# -*- coding: utf-8 -*-
# Fine-tuning with LoRA

from transformers import AutoTokenizer, AutoModelForMaskedLM
from peft import LoraConfig, get_peft_model

model_dir = "/root/autodl-tmp/models/pretrained/google-bert/bert-base-chinese"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Load the BERT model with an MLM head
model = AutoModelForMaskedLM.from_pretrained(model_dir)

# Configure LoRA
lora_config = LoraConfig(
    task_type="TOKEN_CLS",   # MLM predicts per token; PEFT has no dedicated MLM task type, so TOKEN_CLS is a workable choice
    r=8,                     # LoRA rank
    lora_alpha=32,           # scaling factor (effective scale is lora_alpha / r)
    lora_dropout=0.1,        # dropout applied inside the LoRA layers
    target_modules=["query", "value"]  # modules to inject LoRA into; the query and value projections are the usual choice for BERT
)
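
# Optional sketch: entries in target_modules are matched against module names
# by suffix, so listing the base model's attention projections shows exactly
# which layers LoRA will wrap (e.g. "bert.encoder.layer.0.attention.self.query").
for name, _ in model.named_modules():
    if name.endswith(("query", "value")):
        print(name)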

# Inject the LoRA adapters into the model
model = get_peft_model(model, lora_config)
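
# print_trainable_parameters() is a PeftModel helper that summarizes how few
# parameters LoRA leaves trainable compared with the full model.
model.print_trainable_parameters()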

# Show the wrapped model; the LoRA layers now appear inside the targeted attention modules
print(model)
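
# Minimal MLM sanity check (a sketch: it assumes the PEFT-wrapped model still
# forwards to the MLM head and returns vocabulary logits; the sentence below
# is just an illustrative example).
import torch

text = "今天天气真[MASK]。"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
# Locate the [MASK] position and decode the highest-scoring token for it
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))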
